import typing
from typing import Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse
from torch_geometric.typing import Size, SparseTensor
{% for module in modules %}
from {{module}} import *
{%- endfor %}


{% include "collect.jinja" %}


def propagate(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for param in signature.param_dict.values() %}
    {{param.name}}: {{param.type_repr}},
{%- endfor %}
    size: Size = None,
) -> {{signature.return_type_repr}}:
    """Run one propagation step, invoking the registered hooks around it.

    The body is rendered from a template: the parameters between
    ``edge_index`` and ``size`` are generated from the template's
    ``signature.param_dict``, and the per-stage argument lists come from
    the template's ``message_args``, ``aggregate_args``, ``update_args``
    and ``message_and_aggregate_args``.

    At runtime, either the fused ``message_and_aggregate`` path or the
    ``message`` -> ``aggregate`` path is executed, followed by ``update``.
    All hook loops are skipped when ``torch.jit.is_scripting()`` or
    ``is_compiling()`` returns true.

    Args:
        edge_index: Forwarded to ``self._check_input`` and then to either
            ``self.message_and_aggregate`` or the collect function.
        size: Forwarded to ``self._check_input`` together with
            ``edge_index``.

    Returns:
        The value returned by ``self.update``, or the replacement returned
        by a propagate forward hook that returned a non-``None`` value.

    Raises:
        NotImplementedError: If the fused path is selected at runtime but
            the template was rendered without fused support.
    """

    # Begin Propagate Forward Pre Hook #########################################
    # Hooks are only run in eager mode (not while scripting or compiling).
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_pre_hooks.values():
            hook_kwargs = dict(
{%- for name in signature.param_dict %}
                {{name}}={{name}},
{%- endfor %}
            )
            res = hook(self, (edge_index, size, hook_kwargs))
            # A non-None result must unpack into (edge_index, size, kwargs);
            # it replaces the current arguments for everything that follows.
            if res is not None:
                edge_index, size, hook_kwargs = res
{%- for name in signature.param_dict %}
                {{name}} = hook_kwargs['{{name}}']
{%- endfor %}
    # End Propagate Forward Pre Hook ###########################################

    # NOTE(review): `_check_input` is defined outside this template;
    # presumably it validates `edge_index` and returns a size container that
    # may be filled in later - confirm against its implementation.
    mutable_size = self._check_input(edge_index, size)

    # Run "fused" message and aggregation (if applicable).
    # The fused path requires `self.fuse` and either a sparse `edge_index`
    # or (outside of scripting) an `EdgeIndex` that is sorted by column on a
    # class that sets `SUPPORTS_FUSED_EDGE_INDEX`.
    # NOTE(review): `EdgeIndex` is not imported explicitly above; presumably
    # it comes from one of the templated star imports - confirm.
    fuse = False
    if self.fuse:
        if is_sparse(edge_index):
            fuse = True
        elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
            if self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col:
                fuse = True

    if fuse:
        # The runtime `fuse` local above is distinct from the template-level
        # `fuse` flag below: the flag decides at render time whether the
        # fused body is emitted or replaced by a NotImplementedError.

{%- if fuse %}
        # Begin Message and Aggregate Forward Pre Hook #########################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_and_aggregate_forward_pre_hooks.values():
                hook_kwargs = dict(
{%- for name in message_and_aggregate_args %}
                    {{name}}={{name}},
{%- endfor %}
                )
                res = hook(self, (edge_index, hook_kwargs))
                # A non-None result must unpack into (edge_index, kwargs).
                if res is not None:
                    edge_index, hook_kwargs = res
{%- for name in message_and_aggregate_args %}
                    {{name}} = hook_kwargs['{{name}}']
{%- endfor %}
        # End Message and Aggregate Forward Pre Hook ##########################

        # Fused arguments are passed positionally, in template order.
        out = self.message_and_aggregate(
            edge_index,
{%- for name in message_and_aggregate_args %}
            {{name}},
{%- endfor %}
        )

        # Begin Message and Aggregate Forward Hook #############################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_and_aggregate_forward_hooks.values():
                hook_kwargs = dict(
{%- for name in message_and_aggregate_args %}
                    {{name}}={{name}},
{%- endfor %}
                )
                res = hook(self, (edge_index, hook_kwargs, ), out)
                # A hook may replace the output; None keeps it unchanged.
                out = res if res is not None else out
        # End Message and Aggregate Forward Hook ###############################

        # On the fused path, `update` reads the function-level arguments
        # directly since no collected `kwargs` object exists here.
        out = self.update(
            out,
{%- for name in update_args %}
            {{name}}={{name}},
{%- endfor %}
        )
{%- else %}
        raise NotImplementedError("'message_and_aggregate' not implemented")
{%- endif %}

    else:

        # NOTE(review): the collect function and `CollectArgs` are not
        # defined in this template body; presumably they are provided by the
        # included collect template - confirm. The result is accessed by
        # attribute (`kwargs.<name>`) below.
        kwargs = self.{{collect_name}}(
            edge_index,
{%- for name in signature.param_dict %}
            {{name}},
{%- endfor %}
            mutable_size,
        )

        # Begin Message Forward Pre Hook #######################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_forward_pre_hooks.values():
                hook_kwargs = dict(
{%- for name in message_args %}
                    {{name}}=kwargs.{{name}},
{%- endfor %}
                )
                res = hook(self, (hook_kwargs, ))
                # The hook may return the kwargs mapping directly or wrapped
                # in a tuple. This assignment runs even when `res` is None,
                # but the value is only consumed inside the check below.
                hook_kwargs = res[0] if isinstance(res, tuple) else res
                if res is not None:
                    # Rebuild the collected arguments: message arguments come
                    # from the hook result, all other fields are carried over.
                    kwargs = CollectArgs(
{%- for name in collect_param_dict %}
{%- if name in message_args %}
                        {{name}}=hook_kwargs['{{name}}'],
{%- else %}
                        {{name}}=kwargs.{{name}},
{%- endif %}
{%- endfor %}
                    )
        # End Message Forward Pre Hook #########################################

        out = self.message(
{%- for name in message_args %}
            {{name}}=kwargs.{{name}},
{%- endfor %}
        )

        # Begin Message Forward Hook ###########################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_forward_hooks.values():
                hook_kwargs = dict(
{%- for name in message_args %}
                    {{name}}=kwargs.{{name}},
{%- endfor %}
                )
                res = hook(self, (hook_kwargs, ), out)
                # A hook may replace the output; None keeps it unchanged.
                out = res if res is not None else out
        # End Message Forward Hook #############################################

        # Begin Aggregate Forward Pre Hook #####################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._aggregate_forward_pre_hooks.values():
                hook_kwargs = dict(
{%- for name in aggregate_args %}
                    {{name}}=kwargs.{{name}},
{%- endfor %}
                )
                res = hook(self, (hook_kwargs, ))
                # Same return convention as the message pre-hook: a mapping
                # or a tuple whose first element is the mapping.
                hook_kwargs = res[0] if isinstance(res, tuple) else res
                if res is not None:
                    # Rebuild the collected arguments: aggregate arguments
                    # come from the hook result, the rest are carried over.
                    kwargs = CollectArgs(
{%- for name in collect_param_dict %}
{%- if name in aggregate_args %}
                        {{name}}=hook_kwargs['{{name}}'],
{%- else %}
                        {{name}}=kwargs.{{name}},
{%- endif %}
{%- endfor %}
                    )
        # End Aggregate Forward Pre Hook #######################################

        out = self.aggregate(
            out,
{%- for name in aggregate_args %}
            {{name}}=kwargs.{{name}},
{%- endfor %}
        )

        # Begin Aggregate Forward Hook #########################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._aggregate_forward_hooks.values():
                hook_kwargs = dict(
{%- for name in aggregate_args %}
                    {{name}}=kwargs.{{name}},
{%- endfor %}
                )
                res = hook(self, (hook_kwargs, ), out)
                # A hook may replace the output; None keeps it unchanged.
                out = res if res is not None else out
        # End Aggregate Forward Hook ###########################################

        out = self.update(
            out,
{%- for name in update_args %}
            {{name}}=kwargs.{{name}},
{%- endfor %}
        )

    # Begin Propagate Forward Hook ############################################
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_hooks.values():
            hook_kwargs = dict(
{%- for name in signature.param_dict %}
                {{name}}={{name}},
{%- endfor %}
            )
            # Post-hooks receive `mutable_size` (the `_check_input` result),
            # not the `size` argument that was passed in.
            res = hook(self, (edge_index, mutable_size, hook_kwargs), out)
            out = res if res is not None else out
    # End Propagate Forward Hook ##############################################

    return out
